{
 "cells": [
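  {
   "cell_type": "markdown",
   "id": "d7aed882",
   "metadata": {},
   "source": [
    "# Subroutines\n",
    "\n",
    "Setting `define_subroutine = True` on an `nn.Module` subclass makes `export_tvm` emit that module's `forward` as its own private Relax function, which the enclosing function then calls. The exported module at the end of this notebook shows the resulting `layer` and `activation` functions."
   ]
  },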
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "02cb2381",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relax\n",
    "from tvm.ir import assert_structural_equal\n",
    "from tvm.relax.frontend import nn\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c9553b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Activation(nn.Module):\n",
    "    \"\"\"SiLU activation wrapped in its own module.\"\"\"\n",
    "\n",
    "    # Export `forward` as its own private Relax function.\n",
    "    define_subroutine = True\n",
    "\n",
    "    def forward(self, state: nn.Tensor) -> nn.Tensor:\n",
    "        \"\"\"Apply SiLU to `state` and return the result.\"\"\"\n",
    "        return nn.op.silu(state)\n",
    "\n",
    "class Layer(nn.Module):\n",
    "    \"\"\"Bias-free matmul against a learned weight, followed by `Activation`.\"\"\"\n",
    "\n",
    "    # Export `forward` as its own private Relax function.\n",
    "    define_subroutine = True\n",
    "\n",
    "    def __init__(self, in_features, out_features):\n",
    "        \"\"\"Create a float32 `(in_features, out_features)` weight and the activation submodule.\"\"\"\n",
    "        super().__init__()\n",
    "        self.weights = nn.Parameter((in_features, out_features), dtype=\"float32\")\n",
    "        self.activation = Activation()\n",
    "\n",
    "    def forward(self, input: nn.Tensor) -> nn.Tensor:\n",
    "        \"\"\"Multiply `input` by the weights, then apply the activation.\"\"\"\n",
    "        # The parameter name `input` is the key used in the export spec, so it is kept as-is.\n",
    "        projected = nn.op.matmul(input, self.weights)\n",
    "        return self.activation(projected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40b6a862",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">_initialize_effect</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            _io: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>null_value()\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object) <span style=\"color: #AA22FF; font-weight: bold\">=</span> (_io,)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function(private<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">activation</span>(state: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        batch_size <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64()\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            silu: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>silu(state)\n",
       "            gv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> silu\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv1\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">forward</span>(input: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">64</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), _io: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object, weights: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object)):\n",
       "        batch_size <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64()\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">2</span>})\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            layer_output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>layer(input, weights)\n",
       "            gv3: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Object)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> layer_output, (_io,)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv3)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv3\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function(private<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">layer</span>(input: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">64</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weights: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #BA2121\">&quot;batch_size&quot;</span>, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        batch_size <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64()\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            matmul: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>matmul(input, weights, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            activation_output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>activation(matmul)\n",
       "            gv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((batch_size, <span style=\"color: #008000\">32</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> activation_output\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv2\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Symbolic batch dimension: the exported signatures carry \"batch_size\" instead of a fixed size.\n",
    "batch_size = tvm.tir.Var(\"batch_size\", \"int64\")\n",
    "input_spec = nn.spec.Tensor((batch_size, 64), \"float32\")\n",
    "\n",
    "mod = Layer(64, 32)\n",
    "# `export_tvm` returns a pair; only the first element (the module that gets printed) is used here.\n",
    "tvm_mod, _ = mod.export_tvm(spec={\"forward\": {\"input\": input_spec}}, debug=True)\n",
    "tvm_mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017eabf2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
